{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 28、Vision Transformer（ViT）模型原理及PyTorch逐行实现"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "论文地址：<https://arxiv.org/abs/2010.11929>\n",
    "## ViT 想法\n",
    "ViT希望把Transformer模型应用到图像识别领域。NLP任务中序列的基本单位是一个字（token）；如果在图像上直接以像素点为单位，序列会过长，而且单个像素的信息量太少，因此ViT以图像区域（patch）作为基本单位来建模。\n",
    "![](http://assets.hypervoid.top/img/2025/07/05/202507051748682-b9aa.png)\n",
    "### DNN角度\n",
    "\n",
    "1. 把图片切分成一个个块（image to patch）\n",
    "2. 对每个patch做仿射变换（线性映射），得到patch embedding\n",
    "\n",
    "### CNN角度\n",
    "\n",
    "将图片变成embedding可以看成卷积过程，其中 `kernel_size=stride` ，卷积后将特征图拉直。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## 借鉴\n",
    "\n",
    "### BERT：class token embedding\n",
    "\n",
    "借鉴BERT的做法，在patch序列的开头添加一个可学习的class token，最终用它对应的输出做分类。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## position embedding\n",
    "对比后发现一维可训练embedding效果好\n",
    "\n",
    "## Encoder\n",
    "只使用了encoder模块，没有使用decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional\n",
    "import torch\n",
    "from torch import Tensor\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 分别使用 unfold 和卷积操作实现 image->embedding 操作\n",
    "def image2emb_naive(img: Tensor, weight: Tensor, patch_size: int = 16) -> Tensor:\n",
    "    \"\"\"Embed non-overlapping square patches via ``F.unfold`` and a matmul.\n",
    "\n",
    "    Args:\n",
    "        img: 4-D image batch; ``F.unfold`` interprets it as (N, C, H, W).\n",
    "        weight: Projection matrix of shape (C * patch_size**2, model_dim).\n",
    "        patch_size: Side length of each patch, also used as the stride.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape (N, num_patches, model_dim).\n",
    "    \"\"\"\n",
    "    assert img.dim() == 4, \"img must be a 4-D (N, C, H, W) tensor\"\n",
    "    # unfold yields (N, C * patch_size**2, num_patches); swap the last two\n",
    "    # axes so that each row holds one flattened patch.\n",
    "    region = F.unfold(img, patch_size, stride=patch_size).transpose(-1, -2)\n",
    "    print(region.shape)  # (N, num_patches, C * patch_size**2)\n",
    "    return region @ weight\n",
    "\n",
    "\n",
    "def image2emb_conv(img: Tensor, kernel, patch_size=16):\n",
    "    \"\"\"Embed image patches with a strided convolution.\n",
    "\n",
    "    ``kernel`` is handed to ``F.conv2d`` as the weight and ``patch_size`` as the\n",
    "    stride. The two trailing axes of the 4-D result are merged and moved ahead\n",
    "    of the channel axis, giving (batch, num_positions, out_channels).\n",
    "    \"\"\"\n",
    "    feature_map = F.conv2d(img, kernel, stride=patch_size)\n",
    "    n, c_out, dim_a, dim_b = feature_map.shape\n",
    "    # Merge the two spatial axes into a single position axis.\n",
    "    flat = feature_map.reshape(n, c_out, dim_a * dim_b)\n",
    "    # Put positions before channels: (n, dim_a * dim_b, c_out).\n",
    "    return flat.permute(0, 2, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 4, 196])\n",
      "emb.shape= torch.Size([2, 4, 8])\n",
      "tensor([[[  0.5155,  -0.3687,  -3.6338,  -1.2878,   1.2945,   1.8334,  -3.7690,\n",
      "           -0.6200],\n",
      "         [  7.3336, -15.6543,   6.0121,   4.6137,   4.7995, -10.7871,   0.3487,\n",
      "            6.9605],\n",
      "         [ -1.3895,   1.5816,  -2.5958,   3.9162,  -5.0130,   2.0175,   6.0092,\n",
      "            3.8195],\n",
      "         [  3.8014,  -2.9465,  10.2140,   9.6521,   4.3162, -10.4472,  11.1373,\n",
      "            8.1564]],\n",
      "\n",
      "        [[ -0.9922, -11.4510,  -7.3893,   0.5645,  -0.6928,   0.1476,  -5.3571,\n",
      "            0.7916],\n",
      "         [  9.9496,  -5.7659,   4.8640,   3.0778,   2.8867, -14.5847,   5.3720,\n",
      "           16.2563],\n",
      "         [ -2.6126,  -3.1197,  -7.0267,  -8.0693,   0.0918,   6.7568,   6.4510,\n",
      "            3.3721],\n",
      "         [  7.1353,  -5.7266,  12.7986,   9.8556,   5.0186, -14.1588,   5.5587,\n",
      "           10.3287]]])\n",
      "emb.shape= torch.Size([2, 4, 8])\n",
      "tensor([[[  0.5155,  -0.3687,  -3.6338,  -1.2878,   1.2945,   1.8334,  -3.7690,\n",
      "           -0.6200],\n",
      "         [  7.3336, -15.6543,   6.0121,   4.6137,   4.7995, -10.7871,   0.3487,\n",
      "            6.9605],\n",
      "         [ -1.3895,   1.5816,  -2.5958,   3.9162,  -5.0130,   2.0175,   6.0092,\n",
      "            3.8195],\n",
      "         [  3.8014,  -2.9465,  10.2140,   9.6521,   4.3162, -10.4471,  11.1373,\n",
      "            8.1564]],\n",
      "\n",
      "        [[ -0.9922, -11.4510,  -7.3893,   0.5645,  -0.6928,   0.1476,  -5.3571,\n",
      "            0.7916],\n",
      "         [  9.9496,  -5.7659,   4.8640,   3.0778,   2.8867, -14.5847,   5.3719,\n",
      "           16.2563],\n",
      "         [ -2.6126,  -3.1197,  -7.0267,  -8.0693,   0.0918,   6.7568,   6.4510,\n",
      "            3.3721],\n",
      "         [  7.1353,  -5.7266,  12.7986,   9.8556,   5.0186, -14.1588,   5.5587,\n",
      "           10.3287]]])\n"
     ]
    }
   ],
   "source": [
    "import torchvision\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "from torchvision import transforms\n",
    "\n",
    "\n",
    "# Load Fashion-MNIST (downloaded on first use) and iterate in batches of 2.\n",
    "ds = torchvision.datasets.mnist.FashionMNIST(\"./fashion_minist_dataset\", True, download=True, transform=transforms.ToTensor())\n",
    "dl = DataLoader(ds, 2, shuffle=False)\n",
    "\n",
    "patch_size, model_dim = 14, 8\n",
    "# Only the first batch is needed to check that both implementations agree.\n",
    "for img, _ in dl:\n",
    "    batch_size, channel, img_height, img_width = img.shape\n",
    "    weight = torch.randn(patch_size**2 * channel, model_dim)\n",
    "    patch_token_emb = image2emb_naive(img, weight, patch_size)\n",
    "    print(\"emb.shape=\", patch_token_emb.shape)\n",
    "    print(patch_token_emb)\n",
    "    # Reuse the same weights as a conv kernel of shape\n",
    "    # (model_dim, channel, patch_size, patch_size) so both paths must match.\n",
    "    kernel = weight.transpose(0, 1).reshape((-1, channel, patch_size, patch_size))\n",
    "    emb2 = image2emb_conv(img, kernel, patch_size)\n",
    "    print(\"emb.shape=\", emb2.shape)\n",
    "    print(emb2)\n",
    "    assert torch.allclose(patch_token_emb, emb2)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepend a classification (cls) token and add position embeddings.\n",
    "\n",
    "patch_size, model_dim = 14, 8\n",
    "max_num_token = 32  # number of rows in the position-embedding table\n",
    "for img, _ in dl:\n",
    "    batch_size, channel, img_height, img_width = img.shape\n",
    "    weight = torch.randn(patch_size**2 * channel, model_dim)\n",
    "    # 1. Patch embeddings: (batch_size, num_patches, model_dim).\n",
    "    patch_token_emb = image2emb_naive(img, weight, patch_size)\n",
    "    # 2. Prepend the cls token embedding along the sequence axis.\n",
    "    # NOTE(review): this draws a separate random cls vector per sample; a real\n",
    "    # model would presumably share one learnable vector across the batch.\n",
    "    cls_token_emb = torch.randn((batch_size, 1, model_dim), requires_grad=True)\n",
    "    emb = torch.cat([cls_token_emb, patch_token_emb], dim=1)\n",
    "    # 3. Position embeddings: take the first seq_len rows of the table and\n",
    "    # let broadcasting add them to every sample in the batch.\n",
    "    pos_emb_table = torch.randn(max_num_token, model_dim, requires_grad=True)\n",
    "    seq_len = emb.shape[1]\n",
    "    assert seq_len <= max_num_token, \"sequence is longer than the position table\"\n",
    "    emb_with_pos = emb + pos_emb_table[:seq_len]\n",
    "    break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
